{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 3: Aggregation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will override the aggregation method of the GIN convolution module of PyTorch Geometric, implementing the following methods:\n",
    "\n",
    "- Principal Neighborhood Aggregation (PNA)\n",
    "- Learning Aggregation Functions (LAF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f5ea0213890>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Message Passing Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import MessagePassing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__annotations__',\n",
       " '__call__',\n",
       " '__check_input__',\n",
       " '__class__',\n",
       " '__collect__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattr__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lift__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__set_size__',\n",
       " '__setattr__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " '_apply',\n",
       " '_get_name',\n",
       " '_load_from_state_dict',\n",
       " '_named_members',\n",
       " '_register_load_state_dict_pre_hook',\n",
       " '_register_state_dict_hook',\n",
       " '_replicate_for_data_parallel',\n",
       " '_save_to_state_dict',\n",
       " '_slow_forward',\n",
       " '_version',\n",
       " 'add_module',\n",
       " 'aggregate',\n",
       " 'apply',\n",
       " 'bfloat16',\n",
       " 'buffers',\n",
       " 'children',\n",
       " 'cpu',\n",
       " 'cuda',\n",
       " 'double',\n",
       " 'dump_patches',\n",
       " 'eval',\n",
       " 'extra_repr',\n",
       " 'float',\n",
       " 'forward',\n",
       " 'half',\n",
       " 'jittable',\n",
       " 'load_state_dict',\n",
       " 'message',\n",
       " 'message_and_aggregate',\n",
       " 'modules',\n",
       " 'named_buffers',\n",
       " 'named_children',\n",
       " 'named_modules',\n",
       " 'named_parameters',\n",
       " 'parameters',\n",
       " 'propagate',\n",
       " 'register_backward_hook',\n",
       " 'register_buffer',\n",
       " 'register_forward_hook',\n",
       " 'register_forward_pre_hook',\n",
       " 'register_parameter',\n",
       " 'requires_grad_',\n",
       " 'share_memory',\n",
       " 'special_args',\n",
       " 'state_dict',\n",
       " 'to',\n",
       " 'train',\n",
       " 'type',\n",
       " 'update',\n",
       " 'zero_grad']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(MessagePassing)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are interested in the <span style='color:Blue'>aggregate</span> method or, if you are using a sparse adjacency matrix, in the <span style='color:Blue'>message_and_aggregate</span> method. Convolutional classes in PyG extend MessagePassing; we construct our custom convolutional classes by extending GINConv."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scatter operation in <span style='color:Blue'>aggregate</span>:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true\" width=\"500\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import GINConv\n",
    "from torch.nn import Linear\n",
    "from laf_model import LAFLayer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LAF Aggregation Module"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"laf.png\" width=\"500\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GINLAFConv(GINConv):\n",
    "    \"\"\"GIN convolution whose aggregation step is delegated to a ``LAFLayer``.\n",
    "\n",
    "    Messages are passed through a sigmoid, aggregated by ``LAFLayer`` according\n",
    "    to ``index`` and projected by a linear layer from ``node_dim * units``\n",
    "    features to ``node_dim`` features.\n",
    "\n",
    "    Args:\n",
    "        nn: network forwarded to ``GINConv``.\n",
    "        units: number of LAF units, forwarded to ``LAFLayer``.\n",
    "        node_dim: feature width of the messages being aggregated.\n",
    "        **kwargs: extra keyword arguments forwarded to ``GINConv``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nn, units=1, node_dim=32, **kwargs):\n",
    "        super().__init__(nn, **kwargs)\n",
    "        self.dim = node_dim\n",
    "        self.units = units\n",
    "        self.laf = LAFLayer(units=units, kernel_initializer='random_uniform')\n",
    "        self.mlp = torch.nn.Linear(node_dim * units, node_dim)\n",
    "\n",
    "    def aggregate(self, inputs, index):\n",
    "        \"\"\"Aggregate ``inputs`` with the LAF layer and project to ``node_dim`` features.\"\"\"\n",
    "        # NOTE(review): the sigmoid presumably maps messages into the range LAFLayer expects - confirm.\n",
    "        squashed = torch.sigmoid(inputs)\n",
    "        pooled = self.laf(squashed, index)\n",
    "        # Flatten the per-unit outputs into one row of dim * units features.\n",
    "        flat = pooled.view((-1, self.dim * self.units))\n",
    "        return self.mlp(flat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PNA Aggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"pna.png\" width=\"800\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GINPNAConv(GINConv):\n",
    "    \"\"\"GIN convolution with Principal Neighbourhood Aggregation (PNA).\n",
    "\n",
    "    Four aggregators (sum, max, mean, variance) are each combined with three\n",
    "    degree scalers (identity, amplification, attenuation). The 12 resulting\n",
    "    feature blocks are concatenated and projected back to ``node_dim`` features.\n",
    "\n",
    "    Args:\n",
    "        nn: network forwarded to ``GINConv``.\n",
    "        node_dim: feature width of the messages being aggregated.\n",
    "        **kwargs: extra keyword arguments forwarded to ``GINConv``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nn, node_dim=32, **kwargs):\n",
    "        super(GINPNAConv, self).__init__(nn, **kwargs)\n",
    "        # 4 aggregators x 3 scalers = 12 blocks of node_dim features.\n",
    "        self.mlp = torch.nn.Linear(node_dim * 12, node_dim)\n",
    "        # Normalisation constant of the degree scalers.\n",
    "        # NOTE(review): presumably the mean of log(degree + 1) over the training graphs - confirm.\n",
    "        self.delta = 2.5749\n",
    "\n",
    "    def aggregate(self, inputs, index, dim_size=None):\n",
    "        \"\"\"Aggregate ``inputs`` per ``index`` with the PNA aggregators and scalers.\n",
    "\n",
    "        Args:\n",
    "            inputs: messages, one row per edge.\n",
    "            index: for every message, the row of the output it contributes to.\n",
    "            dim_size: number of output rows. When ``None`` it is inferred from\n",
    "                ``index``; passing it keeps one row per node even if the last\n",
    "                nodes receive no message.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with one row per node and ``node_dim`` columns.\n",
    "        \"\"\"\n",
    "        sums = torch_scatter.scatter_add(inputs, index, dim=0, dim_size=dim_size)\n",
    "        maxs = torch_scatter.scatter_max(inputs, index, dim=0, dim_size=dim_size)[0]\n",
    "        means = torch_scatter.scatter_mean(inputs, index, dim=0, dim_size=dim_size)\n",
    "        sq_means = torch_scatter.scatter_mean(inputs ** 2, index, dim=0, dim_size=dim_size)\n",
    "        # relu guards against tiny negative values caused by floating-point error.\n",
    "        var = torch.relu(sq_means - means ** 2)\n",
    "        aggrs = [sums, maxs, means, var]\n",
    "\n",
    "        # Number of messages per node, padded to the number of aggregated rows.\n",
    "        degree = index.bincount(minlength=sums.size(0)).float().view(-1, 1)\n",
    "        # Clamp to 1 so nodes without messages do not divide by log(1) = 0.\n",
    "        log_degree = torch.log(degree.clamp(min=1.) + 1.)\n",
    "\n",
    "        # PNA scalers are functions of log(degree + 1), normalised by delta.\n",
    "        amplification = log_degree / self.delta\n",
    "        attenuation = self.delta / log_degree\n",
    "        scaled = aggrs + [amplification * a for a in aggrs] + [attenuation * a for a in aggrs]\n",
    "        return self.mlp(torch.cat(scaled, dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test the new classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import MessagePassing, SAGEConv, GINConv, global_add_pool\n",
    "import torch_scatter\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Sequential, Linear, ReLU\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.data import DataLoader\n",
    "import os.path as osp\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip\n",
      "Extracting data/TU/MUTAG/MUTAG.zip\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "# Load and shuffle MUTAG; the first tenth of the shuffled graphs is the test split.\n",
    "path = osp.join('./', 'data', 'TU')\n",
    "dataset = TUDataset(path, name='MUTAG').shuffle()\n",
    "n_test = len(dataset) // 10\n",
    "test_dataset, train_dataset = dataset[:n_test], dataset[n_test:]\n",
    "train_loader = DataLoader(train_dataset, batch_size=128)\n",
    "test_loader = DataLoader(test_dataset, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LAFNet(torch.nn.Module):\n",
    "    \"\"\"Graph classifier built from five ``GINLAFConv`` layers.\n",
    "\n",
    "    Each convolution is followed by ReLU and batch normalisation; node features\n",
    "    are then sum-pooled per graph and classified by two linear layers.\n",
    "    Input and output sizes are read from the notebook-level ``dataset``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        dim = 32\n",
    "        units = 3\n",
    "        in_dim = dataset.num_features\n",
    "\n",
    "        # Layers are created in order and stored as conv1/bn1 ... conv5/bn5.\n",
    "        for layer in range(1, 6):\n",
    "            mlp = Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))\n",
    "            setattr(self, 'conv{}'.format(layer), GINLAFConv(mlp, units=units, node_dim=in_dim))\n",
    "            setattr(self, 'bn{}'.format(layer), torch.nn.BatchNorm1d(dim))\n",
    "            in_dim = dim  # every layer after the first works on hidden features\n",
    "\n",
    "        self.fc1 = Linear(dim, dim)\n",
    "        self.fc2 = Linear(dim, dataset.num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Return per-graph log-probabilities over the classes.\"\"\"\n",
    "        for layer in range(1, 6):\n",
    "            conv = getattr(self, 'conv{}'.format(layer))\n",
    "            bn = getattr(self, 'bn{}'.format(layer))\n",
    "            x = bn(F.relu(conv(x, edge_index)))\n",
    "        x = global_add_pool(x, batch)\n",
    "        x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training)\n",
    "        return F.log_softmax(self.fc2(x), dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PNANet(torch.nn.Module):\n",
    "    \"\"\"Graph classifier built from five ``GINPNAConv`` layers.\n",
    "\n",
    "    Each convolution is followed by ReLU and batch normalisation; node features\n",
    "    are then sum-pooled per graph and classified by two linear layers.\n",
    "    Input and output sizes are read from the notebook-level ``dataset``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        dim = 32\n",
    "        in_dim = dataset.num_features\n",
    "\n",
    "        # Layers are created in order and stored as conv1/bn1 ... conv5/bn5.\n",
    "        for layer in range(1, 6):\n",
    "            mlp = Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))\n",
    "            setattr(self, 'conv{}'.format(layer), GINPNAConv(mlp, node_dim=in_dim))\n",
    "            setattr(self, 'bn{}'.format(layer), torch.nn.BatchNorm1d(dim))\n",
    "            in_dim = dim  # every layer after the first works on hidden features\n",
    "\n",
    "        self.fc1 = Linear(dim, dim)\n",
    "        self.fc2 = Linear(dim, dataset.num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Return per-graph log-probabilities over the classes.\"\"\"\n",
    "        for layer in range(1, 6):\n",
    "            conv = getattr(self, 'conv{}'.format(layer))\n",
    "            bn = getattr(self, 'bn{}'.format(layer))\n",
    "            x = bn(F.relu(conv(x, edge_index)))\n",
    "        x = global_add_pool(x, batch)\n",
    "        x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training)\n",
    "        return F.log_softmax(self.fc2(x), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GINNet(torch.nn.Module):\n",
    "    \"\"\"Graph classifier built from five plain ``GINConv`` layers.\n",
    "\n",
    "    Each convolution is followed by ReLU and batch normalisation; node features\n",
    "    are then sum-pooled per graph and classified by two linear layers.\n",
    "    Input and output sizes are read from the notebook-level ``dataset``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        dim = 32\n",
    "        in_dim = dataset.num_features\n",
    "\n",
    "        # Layers are created in order and stored as conv1/bn1 ... conv5/bn5.\n",
    "        for layer in range(1, 6):\n",
    "            mlp = Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))\n",
    "            setattr(self, 'conv{}'.format(layer), GINConv(mlp))\n",
    "            setattr(self, 'bn{}'.format(layer), torch.nn.BatchNorm1d(dim))\n",
    "            in_dim = dim  # every layer after the first works on hidden features\n",
    "\n",
    "        self.fc1 = Linear(dim, dim)\n",
    "        self.fc2 = Linear(dim, dataset.num_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, batch):\n",
    "        \"\"\"Return per-graph log-probabilities over the classes.\"\"\"\n",
    "        for layer in range(1, 6):\n",
    "            conv = getattr(self, 'conv{}'.format(layer))\n",
    "            bn = getattr(self, 'bn{}'.format(layer))\n",
    "            x = bn(F.relu(conv(x, edge_index)))\n",
    "        x = global_add_pool(x, batch)\n",
    "        x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training)\n",
    "        return F.log_softmax(self.fc2(x), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001, Train Loss: 1.7396854, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 002, Train Loss: 0.8858100, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 003, Train Loss: 0.7421182, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 004, Train Loss: 0.5249412, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 005, Train Loss: 0.5756501, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 006, Train Loss: 0.5038144, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 007, Train Loss: 0.4974038, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 008, Train Loss: 0.4389207, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 009, Train Loss: 0.4597056, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 010, Train Loss: 0.3515818, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 011, Train Loss: 0.4357632, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 012, Train Loss: 0.3752011, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 013, Train Loss: 0.3485447, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 014, Train Loss: 0.3208427, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 015, Train Loss: 0.3636578, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 016, Train Loss: 0.3069694, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 017, Train Loss: 0.3395566, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 018, Train Loss: 0.2716209, Train Acc: 0.6705882, Test Acc: 0.6111111\n",
      "Epoch: 019, Train Loss: 0.2646066, Train Acc: 0.6823529, Test Acc: 0.6666667\n",
      "Epoch: 020, Train Loss: 0.3058428, Train Acc: 0.7117647, Test Acc: 0.8333333\n",
      "Epoch: 021, Train Loss: 0.2384960, Train Acc: 0.7235294, Test Acc: 0.8333333\n",
      "Epoch: 022, Train Loss: 0.2641419, Train Acc: 0.7411765, Test Acc: 0.7777778\n",
      "Epoch: 023, Train Loss: 0.2356885, Train Acc: 0.7470588, Test Acc: 0.7777778\n",
      "Epoch: 024, Train Loss: 0.2005905, Train Acc: 0.7470588, Test Acc: 0.7777778\n",
      "Epoch: 025, Train Loss: 0.2058828, Train Acc: 0.7588235, Test Acc: 0.7777778\n",
      "Epoch: 026, Train Loss: 0.2233447, Train Acc: 0.7764706, Test Acc: 0.7777778\n",
      "Epoch: 027, Train Loss: 0.1715136, Train Acc: 0.7764706, Test Acc: 0.7777778\n",
      "Epoch: 028, Train Loss: 0.1635957, Train Acc: 0.7882353, Test Acc: 0.7777778\n",
      "Epoch: 029, Train Loss: 0.1446990, Train Acc: 0.8000000, Test Acc: 0.7777778\n",
      "Epoch: 030, Train Loss: 0.1378348, Train Acc: 0.8058824, Test Acc: 0.7777778\n",
      "Epoch: 031, Train Loss: 0.1356434, Train Acc: 0.8411765, Test Acc: 0.8333333\n",
      "Epoch: 032, Train Loss: 0.1354265, Train Acc: 0.8647059, Test Acc: 0.8333333\n",
      "Epoch: 033, Train Loss: 0.1402526, Train Acc: 0.8823529, Test Acc: 0.8333333\n",
      "Epoch: 034, Train Loss: 0.1395012, Train Acc: 0.8823529, Test Acc: 0.7777778\n",
      "Epoch: 035, Train Loss: 0.1279946, Train Acc: 0.8941176, Test Acc: 0.7777778\n",
      "Epoch: 036, Train Loss: 0.1274087, Train Acc: 0.8941176, Test Acc: 0.7777778\n",
      "Epoch: 037, Train Loss: 0.1139296, Train Acc: 0.9058824, Test Acc: 0.7777778\n",
      "Epoch: 038, Train Loss: 0.1023573, Train Acc: 0.9058824, Test Acc: 0.7777778\n",
      "Epoch: 039, Train Loss: 0.0901502, Train Acc: 0.9176471, Test Acc: 0.7777778\n",
      "Epoch: 040, Train Loss: 0.0955645, Train Acc: 0.9000000, Test Acc: 0.7777778\n",
      "Epoch: 041, Train Loss: 0.0966852, Train Acc: 0.9294118, Test Acc: 0.7777778\n",
      "Epoch: 042, Train Loss: 0.0714682, Train Acc: 0.9411765, Test Acc: 0.8333333\n",
      "Epoch: 043, Train Loss: 0.0903990, Train Acc: 0.9352941, Test Acc: 0.8333333\n",
      "Epoch: 044, Train Loss: 0.0789564, Train Acc: 0.9529412, Test Acc: 0.8333333\n",
      "Epoch: 045, Train Loss: 0.0641856, Train Acc: 0.9705882, Test Acc: 0.8333333\n",
      "Epoch: 046, Train Loss: 0.0627783, Train Acc: 0.9529412, Test Acc: 0.8333333\n",
      "Epoch: 047, Train Loss: 0.0673456, Train Acc: 0.9588235, Test Acc: 0.8333333\n",
      "Epoch: 048, Train Loss: 0.0492790, Train Acc: 0.9588235, Test Acc: 0.8888889\n",
      "Epoch: 049, Train Loss: 0.0588718, Train Acc: 0.9588235, Test Acc: 0.8888889\n",
      "Epoch: 050, Train Loss: 0.0616084, Train Acc: 0.9588235, Test Acc: 0.8333333\n",
      "Epoch: 051, Train Loss: 0.0431585, Train Acc: 0.9705882, Test Acc: 0.8333333\n",
      "Epoch: 052, Train Loss: 0.0400881, Train Acc: 0.9647059, Test Acc: 0.8333333\n",
      "Epoch: 053, Train Loss: 0.0389088, Train Acc: 0.9647059, Test Acc: 0.8333333\n",
      "Epoch: 054, Train Loss: 0.0358927, Train Acc: 0.9588235, Test Acc: 0.8333333\n",
      "Epoch: 055, Train Loss: 0.0508927, Train Acc: 0.9588235, Test Acc: 0.8333333\n",
      "Epoch: 056, Train Loss: 0.0362176, Train Acc: 0.9588235, Test Acc: 0.7777778\n",
      "Epoch: 057, Train Loss: 0.0419180, Train Acc: 0.9588235, Test Acc: 0.7777778\n",
      "Epoch: 058, Train Loss: 0.0403139, Train Acc: 0.9705882, Test Acc: 0.7777778\n",
      "Epoch: 059, Train Loss: 0.0279559, Train Acc: 0.9705882, Test Acc: 0.7777778\n",
      "Epoch: 060, Train Loss: 0.0347687, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 061, Train Loss: 0.0227536, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 062, Train Loss: 0.0429871, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 063, Train Loss: 0.0361861, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 064, Train Loss: 0.0461259, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 065, Train Loss: 0.0225341, Train Acc: 0.9705882, Test Acc: 0.7777778\n",
      "Epoch: 066, Train Loss: 0.0393402, Train Acc: 0.9705882, Test Acc: 0.7777778\n",
      "Epoch: 067, Train Loss: 0.0239863, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 068, Train Loss: 0.0257812, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 069, Train Loss: 0.0205532, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 070, Train Loss: 0.0287459, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 071, Train Loss: 0.0143993, Train Acc: 0.9823529, Test Acc: 0.7222222\n",
      "Epoch: 072, Train Loss: 0.0297863, Train Acc: 0.9823529, Test Acc: 0.7222222\n",
      "Epoch: 073, Train Loss: 0.0265774, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 074, Train Loss: 0.0142347, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 075, Train Loss: 0.0211337, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 076, Train Loss: 0.0331199, Train Acc: 0.9764706, Test Acc: 0.7777778\n",
      "Epoch: 077, Train Loss: 0.0315297, Train Acc: 0.9764706, Test Acc: 0.7777778\n",
      "Epoch: 078, Train Loss: 0.0160678, Train Acc: 0.9764706, Test Acc: 0.7777778\n",
      "Epoch: 079, Train Loss: 0.0257950, Train Acc: 0.9764706, Test Acc: 0.7777778\n",
      "Epoch: 080, Train Loss: 0.0248993, Train Acc: 0.9764706, Test Acc: 0.7777778\n",
      "Epoch: 081, Train Loss: 0.0197920, Train Acc: 0.9764706, Test Acc: 0.7777778\n",
      "Epoch: 082, Train Loss: 0.0130890, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 083, Train Loss: 0.0220011, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 084, Train Loss: 0.0162781, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 085, Train Loss: 0.0145022, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 086, Train Loss: 0.0093531, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 087, Train Loss: 0.0175143, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 088, Train Loss: 0.0117858, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 089, Train Loss: 0.0176236, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 090, Train Loss: 0.0195194, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 091, Train Loss: 0.0120206, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 092, Train Loss: 0.0136675, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 093, Train Loss: 0.0124988, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 094, Train Loss: 0.0120447, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 095, Train Loss: 0.0187258, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 096, Train Loss: 0.0119458, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 097, Train Loss: 0.0133186, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 098, Train Loss: 0.0128858, Train Acc: 0.9705882, Test Acc: 0.7222222\n",
      "Epoch: 099, Train Loss: 0.0107982, Train Acc: 0.9764706, Test Acc: 0.7222222\n",
      "Epoch: 100, Train Loss: 0.0115271, Train Acc: 0.9764706, Test Acc: 0.7222222\n"
     ]
    }
   ],
   "source": [
    "# Pick the device and the architecture to train, then create its optimizer.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "net = \"PNA\"  # one of \"LAF\", \"PNA\", \"GIN\"\n",
    "if net == \"LAF\":\n",
    "    model = LAFNet().to(device)\n",
    "elif net == \"PNA\":\n",
    "    model = PNANet().to(device)\n",
    "elif net == \"GIN\":\n",
    "    # Bind the GIN network to `model` so the optimizer created next trains it.\n",
    "    model = GINNet().to(device)\n",
    "else:\n",
    "    # Fail loudly instead of training a stale or undefined `model`.\n",
    "    raise ValueError('Unknown net: {!r}'.format(net))\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    \"\"\"Run one optimisation pass over ``train_loader``.\n",
    "\n",
    "    When ``epoch`` is 51 the learning rate of every parameter group is halved\n",
    "    before the pass starts.\n",
    "\n",
    "    Returns:\n",
    "        The loss averaged over all graphs in ``train_dataset``.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    if epoch == 51:\n",
    "        for group in optimizer.param_groups:\n",
    "            group['lr'] *= 0.5\n",
    "\n",
    "    total_loss = 0\n",
    "    for batch in train_loader:\n",
    "        batch = batch.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        log_probs = model(batch.x, batch.edge_index, batch.batch)\n",
    "        batch_loss = F.nll_loss(log_probs, batch.y)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        # nll_loss is a batch mean; weight it by the batch size before summing.\n",
    "        total_loss += batch_loss.item() * batch.num_graphs\n",
    "    return total_loss / len(train_dataset)\n",
    "\n",
    "\n",
    "def test(loader):\n",
    "    \"\"\"Return the fraction of graphs in ``loader`` that ``model`` classifies correctly.\n",
    "\n",
    "    Args:\n",
    "        loader: data loader whose ``dataset`` length is used as the denominator.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    correct = 0\n",
    "    # Evaluation needs no gradients; skipping autograd saves memory and time.\n",
    "    with torch.no_grad():\n",
    "        for data in loader:\n",
    "            data = data.to(device)\n",
    "            output = model(data.x, data.edge_index, data.batch)\n",
    "            pred = output.max(dim=1)[1]  # index of the highest log-probability\n",
    "            correct += pred.eq(data.y).sum().item()\n",
    "    return correct / len(loader.dataset)\n",
    "\n",
    "\n",
    "# Train for 100 epochs and report the loss and both accuracies after each one.\n",
    "report = 'Epoch: {:03d}, Train Loss: {:.7f}, Train Acc: {:.7f}, Test Acc: {:.7f}'\n",
    "for epoch in range(1, 101):\n",
    "    train_loss = train(epoch)\n",
    "    train_acc, test_acc = test(train_loader), test(test_loader)\n",
    "    print(report.format(epoch, train_loss, train_acc, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
